import csv
import gzip

import torch
from bokeh.util.serialization import NP_EPOCH
from torch.utils.data import Dataset, DataLoader

class NameDataset(Dataset):
    """Name -> country dataset read from a gzip-compressed CSV file.

    Every non-blank CSV row supplies a name in column 0 and a country in
    column 1.  Countries are encoded as integer ids given by their position
    in the sorted list of distinct countries found in the loaded file.
    """

    def __init__(self, is_train=True):
        """Load the training split if ``is_train`` is true, else the test split.

        The data file path is relative to the current working directory.
        """
        super().__init__()
        filename = './../data/name2country/names_train.csv.gz' if is_train else './../data/name2country/names_test.csv.gz'
        # newline='' leaves newline handling to the csv module, as its
        # documentation recommends for file objects passed to csv.reader.
        with gzip.open(filename, 'rt', newline='') as f:
            # csv.reader yields an empty list for a blank line; drop those so
            # a stray blank line does not raise IndexError below.
            rows = [row for row in csv.reader(f) if row]

        # Parallel lists: names[i] belongs to countrys[i].
        self.names = [row[0] for row in rows]
        self.countrys = [row[1] for row in rows]
        self.len = len(rows)

        # Sorted distinct countries; the list index is the country id.
        self.country_list = sorted(set(self.countrys))
        self.country_Dict = self.get_Country_Dict()
        self.country_num = len(self.country_list)

    def __getitem__(self, index):
        """Return ``(name, country_id)`` for the sample at ``index``."""
        return self.names[index], self.country_Dict[self.countrys[index]]

    def __len__(self):
        """Return the number of samples loaded from the file."""
        return self.len

    def get_Country_Dict(self):
        """Return a dict mapping each country name to its index in ``country_list``."""
        return {country: i for i, country in enumerate(self.country_list)}

    def id2country(self, index):
        """Return the country name stored at position ``index`` of ``country_list``."""
        return self.country_list[index]

    def get_Country_num(self):
        """Return the number of distinct countries in the loaded file."""
        return self.country_num


import time
import csv
import torch
from torch.utils.data import DataLoader , Dataset
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import gzip
import numpy as np
import math
import matplotlib.pyplot as plt

# Hyperparameters and run-time switches for this script.
# presumably the hidden-state width of the recurrent model — TODO confirm against the model definition
HIDDEN_SIZE = 100
# batch size passed to both the train and test DataLoader
BATCH_SIZE = 256
# presumably the number of stacked recurrent layers — TODO confirm
N_LAYERS = 2
# presumably the character vocabulary size (name2ascii encodes characters with ord()) — TODO confirm
N_CHARS = 128
# NOTE(review): looks like the number of training epochs; the name may be a typo for N_EPOCHS — confirm before renaming
NP_EPOCHS = 100
# when True, create_tensor() moves tensors to CUDA via .cuda()
USE_GPU = False

def name2ascii(name):
    """Return ``(codes, length)`` where ``codes`` holds ``ord(c)`` for each character of ``name``."""
    codes = list(map(ord, name))
    return codes, len(codes)

def create_tensor(tensor):
    """Return ``tensor`` moved to CUDA when ``USE_GPU`` is set, otherwise unchanged."""
    return tensor.cuda() if USE_GPU else tensor

def timesince(since):
    """Format the time elapsed since ``since`` (a ``time.time()`` value) as ``'<m>min <s>s'``."""
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    # Seconds left over after removing the whole minutes.
    seconds = elapsed - minutes * 60
    return '%dmin %ds' % (minutes, seconds)


# Build each split once and reuse it; every NameDataset() call re-reads and
# re-parses its gzip file, so constructing extra throwaway instances is wasted I/O.
train_dataset = NameDataset(is_train=True)
test_dataset = NameDataset(is_train=False)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# NOTE(review): the class count is taken from the test split, and each split
# builds its own country -> id mapping; this assumes both files contain the
# same set of countries — confirm.
N_COUNTRY = test_dataset.get_Country_num()

# Training split, kept under its original module-level name.
names = train_dataset



































